import numpy as np


# from variables import num_of_band, Nu, Ny, Nv, gridN
# import time


def _bose_occupation(energies, mu, prefactor, kT):
    """Return prefactor / (exp((energies - mu) / kT) - 1), element-wise."""
    return prefactor * (1 / (np.exp((energies - mu) / kT) - 1))


def ChemicalPotential(energies, AtomTemperture, AtomDensityHot, **kw):
    """
    Find the chemical potential whose total Bose-Einstein density matches a target.

    The total density is the occupation 1 / (exp((E - mu) / kT) - 1), scaled by a
    fixed momentum-cell prefactor and summed over every entry of ``energies``.
    A list of trial chemical potentials is scanned in order. The value whose
    total density is closest to ``AtomDensityHot`` is kept. If it is still more
    than 1 % off, the scan is repeated recursively on a 400-point grid between
    the neighbours of the best value.

    Parameters
    ----------
    energies : ndarray
        At least 2-D. ``np.shape(energies)[1]`` is used as the grid size and
        ``energies[0, ...]`` supplies the default ``Emin``. The thermal energy
        is expressed as ``kB * T / Er``, so energies are presumably in units of
        the recoil energy ``Er`` (TODO confirm with callers).
    AtomTemperture : float
        Temperature. It is multiplied by ``kB`` in J/K, so presumably in kelvin.
    AtomDensityHot : float
        Target value for ``np.sum(ns)``.
    **kw
        Emin : float, optional
            Reference energy. The default trial list runs from ``Emin - 0.005``
            down to ``Emin - 2`` in 400 steps. Defaults to
            ``np.min(energies[0, ...])``.
        mu_list : array_like, optional
            Explicit trial chemical potentials, scanned in the given order.

    Returns
    -------
    mu : float
        Chemical potential, in the same unit as ``energies``.
    ns : ndarray
        Scaled occupation of every state at ``mu`` (same shape as ``energies``).
    nsp : ndarray
        ``ns`` divided by its sum over axis 0.

    Raises
    ------
    ValueError
        If the trial list ``mu_list`` is empty.
    """
    gridN = np.shape(energies)[1]
    hbar = 1.0545718e-34
    kB = 1.38064852e-23
    WaveLength = 786.99e-9
    k0 = hbar * 2 * np.pi / WaveLength
    Mass = 1.443160648e-25
    Er = k0 ** 2 / (2 * Mass)
    dp = 4 * k0 ** 3 / (gridN ** 3)
    prefactor = dp / ((2 * np.pi * hbar) ** 3)
    kT = kB * AtomTemperture / Er  # thermal energy in units of Er

    Emin = kw.get('Emin', np.min(energies[0, ...]))
    dmu = -0.005
    Lmu = 400
    mu_list = kw.get('mu_list', np.linspace(Emin + dmu, Emin + dmu * Lmu, Lmu))

    # Emin + dmu is the highest value of the default (descending) trial list.
    # If even that density is below the target, shift Emin up and start over.
    ns = _bose_occupation(energies, Emin + dmu, prefactor, kT)
    if np.sum(ns) < AtomDensityHot:
        print('initial value wrong, try again!')
        return ChemicalPotential(energies, AtomTemperture, AtomDensityHot, Emin=Emin - dmu * Lmu)

    if len(mu_list) == 0:
        raise ValueError('mu_list must contain at least one trial value')

    # |N(mu) - target| is expected to shrink until the target is crossed and
    # grow afterwards, so stop at the first increase. The previous trial value
    # is then the closest one.
    ep0 = 1e20
    for mu_index, mu in enumerate(mu_list):
        ns = _bose_occupation(energies, mu, prefactor, kT)
        ep = np.abs(np.sum(ns) - AtomDensityHot)
        if ep0 < ep:
            best_index = max(mu_index - 1, 0)
            break
        ep0 = ep
    else:
        # The error never rose again, so the last trial value is the best one.
        # If it is still too far off, the target lies beyond this list:
        # shift Emin down and rescan.
        if np.abs(np.sum(ns) - AtomDensityHot) / AtomDensityHot > 0.01:
            print('The loop value exceeds set value, try again!')
            return ChemicalPotential(energies, AtomTemperture, AtomDensityHot, Emin=Emin + dmu * Lmu)
        best_index = mu_index

    mu = mu_list[best_index]
    ns = _bose_occupation(energies, mu, prefactor, kT)
    if np.abs(np.sum(ns) - AtomDensityHot) / AtomDensityHot > 0.01:
        # Refine between the neighbours of the best value. Clamp the indices so
        # an early break cannot wrap around to the end of the list.
        lo = max(best_index - 1, 0)
        hi = min(best_index + 1, len(mu_list) - 1)
        mu_list_new = np.linspace(mu_list[lo], mu_list[hi], 400)
        # Keep the same Emin. The initial check above already passed with it,
        # whereas shifting it would fail that check and recurse forever.
        return ChemicalPotential(energies, AtomTemperture, AtomDensityHot,
                                 Emin=Emin, mu_list=mu_list_new)
    nsp = ns / np.sum(ns, 0)
    return mu, ns, nsp
